{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dCpvgG0vwXAZ"
   },
   "source": [
    "## Predicting a patient's risk of developing pancreatic cancer with a fine-tuned Med-BERT model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "hsZvic2YxnTz"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0510 22:52:33.998448 140259224065792 file_utils.py:35] PyTorch version 1.2.0 available.\n"
     ]
    }
   ],
   "source": [
    "### Required Packages\n",
    "from termcolor import colored\n",
    "import math\n",
    "from sklearn.model_selection import train_test_split\n",
    "import pandas as pd\n",
    "import random\n",
    "import numpy as np\n",
    "from datetime import datetime\n",
    "import pickle as pkl\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"   \n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "from torch import optim\n",
    "import tqdm\n",
    "import time\n",
    "import transformers\n",
    "\n",
    "from sklearn.metrics import roc_auc_score  \n",
    "from sklearn.metrics import roc_curve \n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.pyplot import cm\n",
    "%matplotlib inline\n",
    "\n",
    "use_cuda = torch.cuda.is_available()\n",
    "import transformers\n",
    "from transformers import BertForSequenceClassification\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Data Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "### The version in this file is updated to dump pt_id for better visualization and results analysis\n",
    "\n",
    "### Below are the key functions for data preparation, formatting input data into features, and model definition\n",
    "class PaddingInputExample(object):\n",
    "  \"\"\"Fake example so the num input examples is a multiple of the batch size.\n",
    "  Based on original BERT code: We use this class instead of `None` because treating `None` as padding\n",
    "  batches could cause silent errors.\n",
    "  \"\"\"\n",
    "\n",
    "class InputFeatures(object):\n",
    "  \"\"\"A single set of features of data.\"\"\"\n",
    "\n",
    "  def __init__(self,\n",
    "               input_ids,\n",
    "               input_mask,\n",
    "               segment_ids,\n",
    "               label_id,\n",
    "               pt_id,\n",
    "               is_real_example=True):\n",
    "    self.input_ids = input_ids\n",
    "    self.input_mask = input_mask\n",
    "    self.segment_ids = segment_ids\n",
    "    self.label_id = label_id\n",
    "    self.is_real_example = is_real_example\n",
    "    self.pt_id = pt_id\n",
    "    \n",
    "    \n",
    "def convert_EHRexamples_to_features(examples,max_seq_length):\n",
    "    \"\"\"Convert a set of `InputExample`s to a list of `InputFeatures`.\"\"\"\n",
    "\n",
    "    features = []\n",
    "    for (ex_index, example) in enumerate(examples):\n",
    "        feature = convert_singleEHR_example(ex_index, example, max_seq_length)\n",
    "        features.append(feature)\n",
    "    return features\n",
    "\n",
    "### This is the EHR version\n",
    "def convert_singleEHR_example(ex_index, example, max_seq_length):\n",
    "    \"\"\"Convert one patient record into a fixed-length feature list.\n",
    "\n",
    "    Args:\n",
    "        ex_index: Position of the example in the dataset (not used here).\n",
    "        example: Either a `PaddingInputExample` or a record indexed as\n",
    "            [0] pt_id, [1] label, [2] code sequence, [3] segment (visit) ids.\n",
    "        max_seq_length: Length every returned sequence is truncated/padded to.\n",
    "\n",
    "    Returns:\n",
    "        [input_ids, input_mask, segment_ids, label_id, pt_id, is_real_example],\n",
    "        the positional layout that `my_collate` indexes into.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `max_seq_length` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if max_seq_length < 1:\n",
    "        # A length of 0 used to make the truncation loop spin forever.\n",
    "        raise ValueError(\"max_seq_length must be >= 1, got %r\" % (max_seq_length,))\n",
    "\n",
    "    if isinstance(example, PaddingInputExample):\n",
    "        # Use the same list layout as a real example so collation can index it.\n",
    "        # (This branch used to build InputFeatures without the required pt_id\n",
    "        # argument and therefore always raised TypeError.)\n",
    "        return [[0] * max_seq_length,\n",
    "                [0] * max_seq_length,\n",
    "                [0] * max_seq_length,\n",
    "                0, None, False]\n",
    "\n",
    "    pt_id = example[0]\n",
    "    label_id = example[1]\n",
    "\n",
    "    # Left-truncate (keep the last max_seq_length entries) and copy, so the\n",
    "    # caller's record is never padded in place. Without the copy, converting\n",
    "    # the same record a second time would treat the padding as real tokens\n",
    "    # and produce an all-ones mask.\n",
    "    input_ids = list(example[2][-max_seq_length:])\n",
    "    segment_ids = list(example[3][-max_seq_length:])\n",
    "\n",
    "    # The mask has 1 for real tokens and 0 for padding tokens. Only real\n",
    "    # tokens are attended to.\n",
    "    input_mask = [1] * len(input_ids)\n",
    "\n",
    "    # Zero-pad up to the sequence length.\n",
    "    pad_len = max_seq_length - len(input_ids)\n",
    "    input_ids += [0] * pad_len\n",
    "    input_mask += [0] * pad_len\n",
    "    segment_ids += [0] * pad_len\n",
    "\n",
    "    # segment_ids only ends up with the right length when it started out as\n",
    "    # long as input_ids, so this also catches mismatched records.\n",
    "    assert len(input_ids) == max_seq_length\n",
    "    assert len(input_mask) == max_seq_length\n",
    "    assert len(segment_ids) == max_seq_length\n",
    "\n",
    "    return [input_ids, input_mask, segment_ids, label_id, pt_id, True]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BERTdataEHR(Dataset):\n",
    "    def __init__(self, Features):\n",
    "           \n",
    "        self.data= Features\n",
    "  \n",
    "                                     \n",
    "    def __getitem__(self, idx, seeDescription = False):\n",
    "\n",
    "        sample = self.data[idx]\n",
    "   \n",
    "        return sample\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)     \n",
    "\n",
    "         \n",
    "#customized parts for EHRdataloader\n",
    "def my_collate(batch):\n",
    "        all_input_ids = []\n",
    "        all_input_mask = []\n",
    "        all_segment_ids = []\n",
    "        all_label_ids = []\n",
    "        all_pt_ids=[]\n",
    "\n",
    "        for feature in batch:\n",
    "            all_input_ids.append(feature[0])\n",
    "            all_input_mask.append(feature[1])\n",
    "            all_segment_ids.append(feature[2])\n",
    "            all_label_ids.append(feature[3])\n",
    "            all_pt_ids.append(feature[4])\n",
    "        return [[all_input_ids, all_input_mask,all_segment_ids,all_label_ids],all_pt_ids]\n",
    "            \n",
    "\n",
    "class BERTdataEHRloader(DataLoader):\n",
    "    \"\"\"`DataLoader` that defaults to `my_collate` and a batch size of 128.\n",
    "\n",
    "    Every constructor argument is forwarded to `torch.utils.data.DataLoader`.\n",
    "    (Before, all options except `batch_size` and `collate_fn` were replaced by\n",
    "    hard-coded defaults, so e.g. `shuffle=True` or `drop_last=True` had no effect.)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataset, batch_size=128, shuffle=False, sampler=None, batch_sampler=None,\n",
    "                 num_workers=0, collate_fn=my_collate, pin_memory=False, drop_last=False,\n",
    "                 timeout=0, worker_init_fn=None):\n",
    "        if batch_sampler is not None:\n",
    "            # DataLoader rejects batch_sampler together with a batch_size other than 1;\n",
    "            # reset this class's own default of 128 so a batch_sampler stays usable.\n",
    "            batch_size = 1\n",
    "        super(BERTdataEHRloader, self).__init__(\n",
    "            dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,\n",
    "            batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,\n",
    "            pin_memory=pin_memory, drop_last=drop_last, timeout=timeout,\n",
    "            worker_init_fn=worker_init_fn)\n",
    "\n",
    " "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EHR_BERT_RNN(nn.Module):\n",
    "    \"\"\"Binary classifier: token representations -> RNN -> linear layer -> sigmoid.\n",
    "\n",
    "    The token representations come from one of three sources, selected by `emb`:\n",
    "      * 'brt': the `.bert` encoder of a `BertForSequenceClassification` loaded from \"./\";\n",
    "      * a string longer than 3 characters: embeddings returned by\n",
    "        `load_pretrain_w2vemb(emb, embed_dim)`;\n",
    "      * anything else: a trainable `nn.Embedding(input_size, embed_dim)`.\n",
    "    In the two non-BERT cases an embedding of the segment ids is added to the\n",
    "    token embedding.\n",
    "\n",
    "    NOTE(review): `load_pretrain_w2vemb` is not defined in the visible part of this\n",
    "    notebook, so the pretrained-embedding path needs it to be provided elsewhere.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size,embed_dim, hidden_size, n_layers=1,dropout_r=0.1,cell_type='LSTM',bi=False,emb=''):\n",
    "        \"\"\"Build the embedding source, the recurrent cell and the output layer.\n",
    "\n",
    "        Args:\n",
    "            input_size: Number of rows of the trainable token embedding; on the\n",
    "                'brt' and pretrained paths it is reassigned but not used afterwards.\n",
    "            embed_dim: Embedding width used by the embedding layers (also passed\n",
    "                to `load_pretrain_w2vemb`).\n",
    "            hidden_size: Hidden size of the recurrent cell.\n",
    "            n_layers: Number of stacked recurrent layers.\n",
    "            dropout_r: Dropout rate handed to the recurrent cell and to `self.dropout`.\n",
    "            cell_type: \"GRU\", \"RNN\" or \"LSTM\"; anything else raises NotImplementedError.\n",
    "            bi: Whether the recurrent cell is bidirectional.\n",
    "            emb: Selects the embedding source (see the class docstring).\n",
    "        \"\"\"\n",
    "        super(EHR_BERT_RNN, self).__init__()\n",
    "        self.n_layers = n_layers\n",
    "        self.hidden_size = hidden_size\n",
    "        self.embed_dim = embed_dim\n",
    "        self.dropout_r = dropout_r\n",
    "        self.cell_type = cell_type\n",
    "        # Decode `emb` into the two flags that pick the embedding source.\n",
    "        if emb=='brt':\n",
    "            self.brt=True\n",
    "            self.pretrain=False\n",
    "        elif len(emb)>3:\n",
    "            self.brt=False\n",
    "            self.pretrain=True\n",
    "            self.pretrained_emb=load_pretrain_w2vemb(emb,self.embed_dim)\n",
    "        else:\n",
    "            self.brt=False\n",
    "            self.pretrain=False\n",
    "        \n",
    "        # self.bi doubles as the multiplier for the output layer's input width.\n",
    "        if bi: self.bi=2 \n",
    "        else: self.bi=1\n",
    "        \n",
    "        # Tensor types used in forward(), chosen from the module-level `use_cuda` flag.\n",
    "        if use_cuda:\n",
    "            self.flt_typ=torch.cuda.FloatTensor\n",
    "            self.lnt_typ=torch.cuda.LongTensor\n",
    "        else: \n",
    "            self.lnt_typ=torch.LongTensor\n",
    "            self.flt_typ=torch.FloatTensor\n",
    "        \n",
    "        if self.brt:\n",
    "            # The model files are looked up in the current working directory (\"./\").\n",
    "            self.PreBERTmodel=BertForSequenceClassification.from_pretrained(\"./\")\n",
    "            input_size=self.PreBERTmodel.bert.config.vocab_size\n",
    "            self.in_size= self.PreBERTmodel.bert.config.hidden_size\n",
    "        \n",
    "        elif self.pretrain: ## to add W2Vec embedding as example\n",
    "            self.embed= nn.Embedding.from_pretrained(self.pretrained_emb)#,freeze=False)  \n",
    "            # Segment-id embedding with 500 rows, so segment ids must stay below 500.\n",
    "            self.vembed= nn.Embedding(500, self.embed_dim,padding_idx=0)\n",
    "            input_size=self.pretrained_emb.shape[0]+1\n",
    "            self.in_size= self.pretrained_emb.shape[1]\n",
    "        else:\n",
    "            input_size=input_size\n",
    "            self.embed= nn.Embedding(input_size, self.embed_dim,padding_idx=0)#,scale_grad_by_freq=True)\n",
    "            # Segment-id embedding with 500 rows, so segment ids must stay below 500.\n",
    "            self.vembed= nn.Embedding(500, self.embed_dim,padding_idx=0)\n",
    "            self.in_size= embed_dim\n",
    "            #if self.time: self.in_size= self.in_size+1  ### place holder - no time information in here yet\n",
    "               \n",
    "        # Map the cell_type string to the corresponding torch recurrent module.\n",
    "        if self.cell_type == \"GRU\":\n",
    "            cell = nn.GRU\n",
    "        elif self.cell_type == \"RNN\":\n",
    "            cell = nn.RNN\n",
    "        elif self.cell_type == \"LSTM\":\n",
    "            cell = nn.LSTM\n",
    "        else:\n",
    "            raise NotImplementedError\n",
    "        \n",
    "        # NOTE(review): self.dropout is created here but is not applied in forward().\n",
    "        self.dropout = nn.Dropout(p=self.dropout_r)\n",
    "        #self.embed_size = embed_size\n",
    "        self.rnn_c = cell(self.in_size, hidden_size,num_layers=n_layers, dropout= dropout_r , bidirectional=bi , batch_first=True)\n",
    "        # Single output unit; its input is widened by self.bi for bidirectional cells.\n",
    "        self.out = nn.Linear(self.hidden_size*self.bi,1)\n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "\n",
    "    def forward(self, sequence):\n",
    "        \"\"\"Score a batch and return `(predictions, labels)`.\n",
    "\n",
    "        Args:\n",
    "            sequence: Indexable batch where [0] holds the token ids, [1] the\n",
    "                attention mask, [2] the segment ids and [3] the labels (the\n",
    "                layout of the first element returned by `my_collate`).\n",
    "\n",
    "        Returns:\n",
    "            Tuple of the sigmoid outputs after `squeeze()` and the labels\n",
    "            converted to a float tensor.\n",
    "        \"\"\"\n",
    "\n",
    "        # Convert the plain Python lists of the batch into tensors.\n",
    "        token_t=torch.from_numpy(np.asarray(sequence[0],dtype=int)).type(self.lnt_typ)\n",
    "        seg_t=torch.from_numpy(np.asarray(sequence[2],dtype=int)).type(self.lnt_typ)\n",
    "        Label_t=torch.from_numpy(np.asarray(sequence[3],dtype=int)).type(self.flt_typ)\n",
    "        \n",
    "        if self.brt:\n",
    "            # The first element of the BERT encoder output is fed to the recurrent cell.\n",
    "            Bert_out=self.PreBERTmodel.bert(input_ids=token_t, attention_mask=torch.from_numpy(np.asarray(sequence[1],dtype=int)).type(self.lnt_typ),\n",
    "                                  token_type_ids=seg_t)\n",
    "            output, hidden = self.rnn_c(Bert_out[0])#,h_0) \n",
    "        else:\n",
    "            # Token embedding plus segment embedding; the attention mask is not used on this path.\n",
    "            embeddings = self.embed(token_t)+self.vembed(seg_t)\n",
    "            #embeddings = self.LayerNorm(embeddings)\n",
    "            #in1= self.dropout(embeddings)\n",
    "            in1=embeddings\n",
    "            output, hidden = self.rnn_c(in1)#,h_0) \n",
    "        \n",
    "        # nn.LSTM returns a (h_n, c_n) tuple as its hidden state; keep only h_n.\n",
    "        if self.cell_type == \"LSTM\":\n",
    "            hidden=hidden[0]\n",
    "        if self.bi==2:\n",
    "            # Concatenate the last two hidden-state entries (presumably the two\n",
    "            # directions of the final layer -- TODO confirm).\n",
    "            output = self.sigmoid(self.out(torch.cat((hidden[-2],hidden[-1]),1)))\n",
    "        else:\n",
    "            output = self.sigmoid(self.out(hidden[-1]))\n",
    "        # NOTE(review): squeeze() drops every size-1 dimension, so a batch holding a\n",
    "        # single example presumably yields a 0-d tensor -- confirm callers handle it.\n",
    "        return output.squeeze(),Label_t\n",
    "\n",
    "\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Print Predictions function\n",
    "def print_preds(model, mbs_list, shuffle = True): \n",
    "    \"\"\"Score every minibatch with `model` and compute the ROC AUC.\n",
    "\n",
    "    Args:\n",
    "        model: Module whose call on `batch[0]` returns `(predictions, labels)`.\n",
    "        mbs_list: List of minibatches, each indexed as [0] model inputs and\n",
    "            [1] patient ids (the layout produced by `my_collate`).\n",
    "        shuffle: If True, the minibatches are visited in random order. The\n",
    "            caller's list itself is left untouched.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(auc, pts_sk, y_real, y_hat)`: the ROC AUC followed by the patient\n",
    "        ids, true labels and predicted scores, all in visiting order.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    # Shuffle a shallow copy so the caller's list keeps its order.\n",
    "    batches = list(mbs_list)\n",
    "    if shuffle:\n",
    "        random.shuffle(batches)\n",
    "\n",
    "    y_real, y_hat, pts_sk = [], [], []\n",
    "    # Inference only: without autograd no graph is built or kept alive per batch.\n",
    "    with torch.no_grad():\n",
    "        for model_inputs, pt_skl in batches:\n",
    "            output, label_tensor = model(model_inputs)\n",
    "            # reshape(-1) also turns a 0-d output (single-example batch) into 1-d.\n",
    "            y_hat.extend(output.detach().cpu().reshape(-1).numpy())\n",
    "            y_real.extend(label_tensor.detach().cpu().reshape(-1).numpy())\n",
    "            pts_sk.extend(pt_skl)\n",
    "\n",
    "    auc = roc_auc_score(y_real, y_hat)\n",
    "    return auc, pts_sk, y_real, y_hat"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "pmFYvkylMwXn"
   },
   "source": [
    "#### Load Data from pickled list\n",
    "\n",
    "The pickled list is a list of lists, where each sublist represents one patient record of the form\n",
    "[pt_id, label, seq_list, segment_list]\n",
    "where\n",
    "    label: 1 if the patient developed pancreatic cancer (case), 0 for a control\n",
    "    seq_list: list of all medical codes across all visits\n",
    "    segment_list: the visit number corresponding to each code in seq_list\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.\n",
    "# Each file is opened in a `with` block so its handle is closed right after loading.\n",
    "with open('pdata/lr_pc_cid_btexp.combined_BertFT.train', 'rb') as fh:\n",
    "    train_f = pkl.load(fh, encoding='bytes')\n",
    "with open('pdata/lr_pc_cid_btexp.combined_BertFT.valid', 'rb') as fh:\n",
    "    valid_f = pkl.load(fh, encoding='bytes')\n",
    "with open('pdata/lr_pc_cid_btexp.combined_BertFT.test', 'rb') as fh:\n",
    "    test_f = pkl.load(fh, encoding='bytes')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20000 5000 2500\n"
     ]
    }
   ],
   "source": [
    "# Sizes are printed in the order: train, test, valid.\n",
    "print(*(len(split) for split in (train_f, test_f, valid_f)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Print Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " creating the list of training minibatches\n",
      " creating the list of test minibatches\n",
      " creating the list of valid minibatches\n"
     ]
    }
   ],
   "source": [
    "# Run settings.\n",
    "MAX_SEQ_LENGTH, BATCH_SIZE = 64, 100\n",
    "bert_config_file = \"sc3_config.json\"\n",
    "results = []\n",
    "\n",
    "# torch.load returns the saved model object; its weights are then refreshed\n",
    "# from the separately saved state dict before switching to eval mode.\n",
    "loaded_model = torch.load('PC_BIGRU_BRT.pth')\n",
    "loaded_model.load_state_dict(torch.load('PC_BIGRU_BRT.st'))\n",
    "loaded_model.eval()\n",
    "\n",
    "# Fixed-length features for each split (converted in the order train, test, valid).\n",
    "train_features, test_features, valid_features = (\n",
    "    convert_EHRexamples_to_features(split_records, MAX_SEQ_LENGTH)\n",
    "    for split_records in (train_f, test_f, valid_f)\n",
    ")\n",
    "\n",
    "# Wrap each feature list in a Dataset.\n",
    "train, test, valid = map(BERTdataEHR, (train_features, test_features, valid_features))\n",
    "\n",
    "# Materialise all minibatches of every split up front.\n",
    "print (' creating the list of training minibatches')\n",
    "train_mbs = [*BERTdataEHRloader(train, batch_size=BATCH_SIZE)]\n",
    "print (' creating the list of test minibatches')\n",
    "test_mbs = [*BERTdataEHRloader(test, batch_size=BATCH_SIZE)]\n",
    "print (' creating the list of valid minibatches')\n",
    "valid_mbs = [*BERTdataEHRloader(valid, batch_size=BATCH_SIZE)]\n",
    "\n",
    "# Score the training split with the restored model.\n",
    "auc_tr, pts_sk_tr, y_real_tr, y_hat_tr = print_preds(loaded_model, train_mbs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9018740796391604"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "auc_tr   # training-set AUC; the recorded run of this notebook gave 0.9018740796391604"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per scored training example: patient id, true label, predicted score.\n",
    "pred_df = pd.DataFrame({'Pt_sk': pts_sk_tr, 'Label': y_real_tr, 'Prediction': y_hat_tr})\n",
    "# Display the examples scored above 0.9 whose true label is 1.\n",
    "pred_df.loc[pred_df['Prediction'].gt(0.9) & pred_df['Label'].eq(1)]"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "From the output of the above cell we found:\n",
    "\n",
    "True Positive List:\n",
    "    pt1 \t1.0 0.987743\n",
    "    pt2 \t1.0 \t0.991851\n",
    "    pt3  \t1.0 \t0.987545\n",
    "    pt4  \t1.0 \t0.991477\n",
    "    pt5  \t1.0 \t0.952177\n",
    "    \n",
    "False Positive List:\n",
    "    pt6  \t0.0 \t0.968320\n",
    "    pt7  \t0.0 \t0.958237\n",
    "    pt8  \t0.0 \t0.976596\n",
    "    pt9  \t0.0 \t0.976421\n",
    "    pt10  \t0.0 \t0.965359"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.\n",
    "# The training records are re-read from disk here; each handle is closed after loading.\n",
    "with open('pdata/lr_pc_cid_btexp.combined_BertFT.train', 'rb') as fh:\n",
    "    train_f = pkl.load(fh, encoding='bytes')\n",
    "with open('pdata/lr_pc_cid_btexp.types.valid', 'rb') as fh:\n",
    "    types_d = pkl.load(fh, encoding='bytes')\n",
    "# Reverse lookup: maps each value of types_d back to its key.\n",
    "types_d_rev = dict(zip(types_d.values(),types_d.keys()))\n",
    "# NOTE(review): absolute, machine-specific path -- adjust to the local data location.\n",
    "diag_det= pd.read_table('/data/LR_test/panc/data/HF_D_DIAGNOSIS',sep='|')\n",
    "\n",
    "# pt1 ... pt10 are placeholders for the patient ids of interest; they are not\n",
    "# defined in this notebook and must be assigned before this cell is run.\n",
    "patients_of_interest = [ pt1 ,pt2 ,pt3 ,pt4 ,pt5 ,pt6 ,pt7 ,pt8 ,pt9 ,pt10 ]\n",
    "# Loop-invariant column selection, built once instead of once per code.\n",
    "diag_cols = diag_det[['DIAGNOSIS_CODE','DIAGNOSIS_DESCRIPTION']]\n",
    "\n",
    "for x in train_f:\n",
    "    if x[0] in patients_of_interest:\n",
    "        print (x)\n",
    "        # Print code and description for every code in the patient's sequence.\n",
    "        for i in x[2]:\n",
    "            print(diag_cols[diag_det['DIAGNOSIS_ID']==types_d_rev[i]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the output of the above cell we found:\n",
    "\n",
    "Codes that appear to be the key drivers of the high scores are:\n",
    "    577.2 CYST AND PSEUDOCYST OF PANCREAS\n",
    "    \n",
    "    576.2  OBSTRUCTION OF BILE DUCT\n",
    "    783.21        LOSS OF WEIGHT\n",
    "    441.4  ABDOMINAL AORTIC ANEURYSM \n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Predicting Movie Reviews with BERT on TF Hub.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "py_37_env",
   "language": "python",
   "name": "py_37_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
